{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"../../../../../Projects/Others/GitHub/UPerNet/\")\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from utils import *\n",
    "import cv2\n",
    "import numpy as np\n",
    "from dataloader import CitySpaceDataset\n",
    "from dataloader import Transforms\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Augmentation hyper-parameters consumed by Transforms / build_dataloader.\n",
    "# Every value is zero except the fill value (114).\n",
    "# NOTE(review): 'data_aug_prespective' spelling is kept as-is -- confirm it\n",
    "# matches the key the project code actually reads.\n",
    "data_hyp = dict(\n",
    "    data_aug_shear=0.0,\n",
    "    data_aug_translate=0.0,\n",
    "    data_aug_degree=0,\n",
    "    data_aug_perspective_p=0.0,\n",
    "    data_aug_prespective=0.0,\n",
    "    data_aug_hsv_p=0.0,\n",
    "    data_aug_hsv_hgain=0.0,\n",
    "    data_aug_hsv_sgain=0.0,\n",
    "    data_aug_hsv_vgain=0.0,\n",
    "    data_aug_fliplr_p=0.0,\n",
    "    data_aug_flipud_p=0.0,\n",
    "    data_aug_fill_value=114,\n",
    "    data_aug_cutout_p=0.0,\n",
    "    data_aug_blur_p=0.0,\n",
    "    data_aug_saturation_p=0.0,\n",
    "    data_aug_crop_p=0.0,\n",
    "    data_aug_brightness_p=0.0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the training dataset directly, passing the transform built from data_hyp.\n",
    "trans = Transforms(data_hyp)\n",
    "dataset_kwargs = dict(\n",
    "    img_dir=\"../../../../../Dataset/Segmentation/cityscapes/image/train/\",\n",
    "    seg_dir=\"../../../../../Dataset/Segmentation/cityscapes/label/train/\",\n",
    "    img_size=[768, 768],\n",
    "    enable_data_aug=True,\n",
    "    transform=trans,\n",
    ")\n",
    "\n",
    "dataset = CitySpaceDataset(**dataset_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1024, 2048, 3) (1024, 2048, 1)\n"
     ]
    }
   ],
   "source": [
    "# Fetch one fixed sample and report the shapes of its image and label map.\n",
    "img, seg, i = dataset[9]\n",
    "print(f\"{img.shape} {seg.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def debug_plot(img, seg, show_label=14):\n",
    "    \"\"\"Plot an image, the binary mask of one label, and the image masked by it.\n",
    "\n",
    "    Args:\n",
    "        img: H x W x 3 image array. Channels are reversed before display\n",
    "            (presumably BGR as loaded by OpenCV -- confirm against the dataset).\n",
    "        seg: label map aligned with ``img``, shaped (H, W) or (H, W, 1).\n",
    "        show_label: label value to highlight.\n",
    "    \"\"\"\n",
    "    seg = np.asarray(seg)\n",
    "    # Accept both (H, W) and (H, W, 1) label maps.\n",
    "    label_map = seg[..., 0] if seg.ndim == 3 else seg\n",
    "    selected = label_map == show_label\n",
    "\n",
    "    seg_tmp = np.zeros(label_map.shape, dtype=np.uint8)\n",
    "    seg_tmp[selected] = 255\n",
    "\n",
    "    # Select pixels with a boolean mask instead of dividing by mask.max():\n",
    "    # that normalisation was 0/0 (NaN) whenever show_label was absent from seg.\n",
    "    masked_img = np.where(selected[..., None], img, 0)\n",
    "    masked_img = np.clip(masked_img, 0, 255).astype(np.uint8)\n",
    "\n",
    "    fig, axes = plt.subplots(1, 3, figsize=[16, 6])\n",
    "    axes[0].imshow(img[..., ::-1])\n",
    "    axes[1].imshow(seg_tmp, cmap=\"gray\")\n",
    "    axes[2].imshow(masked_img[..., ::-1])\n",
    "    fig.suptitle(f\"label: {show_label}\")\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw a random sample and plot every label it contains other than 0.\n",
    "# The index range follows the dataset size instead of a hard-coded 1000, which\n",
    "# could index past the end of a smaller dataset and never reached later samples.\n",
    "# NOTE(review): assumes CitySpaceDataset implements __len__ -- confirm.\n",
    "rnd_i = np.random.randint(0, len(dataset))\n",
    "img, seg, i = dataset[rnd_i]\n",
    "print(img.shape, seg.shape)\n",
    "for lab in np.unique(seg):\n",
    "    if lab != 0.0:\n",
    "        debug_plot(img, seg, lab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataloader import build_dataloader\n",
    "\n",
    "# Build dataset, dataloader and prefetcher through the project's factory,\n",
    "# with augmentation switched off. This rebinds `dataset`.\n",
    "kwargs = dict(\n",
    "    img_dir=\"../../../../../Dataset/Segmentation/cityscapes/image/train/\",\n",
    "    seg_dir=\"../../../../../Dataset/Segmentation/cityscapes/label/train/\",\n",
    "    data_aug_hyp=data_hyp,\n",
    "    enable_data_aug=False,\n",
    "    dst_size=768,\n",
    ")\n",
    "\n",
    "dataset, dataloader, prefetcher = build_dataloader(**kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# next() needs an iterator; an object that is only iterable (for example a\n",
    "# torch DataLoader) raises TypeError. iter() returns an iterator either way,\n",
    "# and returns the object itself if it already is one.\n",
    "data = next(iter(dataloader))\n",
    "\n",
    "imgs = data['img']\n",
    "segs = data['seg']\n",
    "ids = data['id']\n",
    "print(imgs.shape, segs.shape, len(ids))\n",
    "# Plot the sample at batch index 1: move axis 1 to the end and scale by 255\n",
    "# (assumes image values are in [0, 1] -- confirm against the dataloader).\n",
    "img = np.clip(imgs.permute(0, 2, 3, 1)[1].numpy() * 255, 0, 255).astype(np.uint8)\n",
    "debug_plot(img, segs.permute(0, 2, 3, 1)[1].numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the whole batch to numpy with axis 1 moved last, then plot each\n",
    "# sample with debug_plot's default label.\n",
    "img_np = imgs.permute(0, 2, 3, 1).contiguous().numpy()\n",
    "seg_np = segs.permute(0, 2, 3, 1).contiguous().numpy()\n",
    "for i, img in enumerate(img_np):\n",
    "    img = np.clip(img * 255, 0, 255).astype(np.uint8)\n",
    "    seg = seg_np[i]\n",
    "    print(ids[i])\n",
    "    debug_plot(img, seg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.15 ('torch1.13')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.15"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "a00c0a220c6a4fa84f8be1eccff78cca396213150bbb3b308447c2ee397323f8"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
